import logging
import os
import sys
import warnings
import time
from datetime import datetime
from typing import Optional
import random


class LogWrapper:
    """Process-wide singleton that echoes messages to stdout and to a log file.

    The first construction configures the root logger (via
    ``logging.basicConfig``) to append to a freshly named file inside
    ``log_dir``. Later constructions return the same instance and ignore
    their arguments.
    """

    _instance = None  # the single shared instance handed out by __new__
    _opened = False   # set once __init__ has configured logging

    def __new__(cls, *args, **kwargs):
        # Reuse the existing instance so every caller shares one configuration.
        if not isinstance(cls._instance, cls):
            cls._instance = object.__new__(cls)
        return cls._instance

    def __init__(self, debug: bool, log_dir: str, calling_file: Optional[str]):
        """Configure file logging on first construction; no-op afterwards.

        :param debug: when True, ``log`` and ``print_and_log`` also write to the log.
        :param log_dir: directory that receives the log file; created if missing.
        :param calling_file: optional prefix for the log file name.
        """
        if LogWrapper._opened:
            return
        rand_id = int(random.random() * time.time_ns())
        stamp = datetime.now().strftime("%m_%d-%H_%M_%S")
        full_name = LogWrapper._build_log_path(log_dir, calling_file, stamp, rand_id,
                                               "PYCHARM_HOSTED" in os.environ)
        if log_dir:
            # logging.FileHandler does not create missing parent directories.
            os.makedirs(log_dir, exist_ok=True)

        # NOTE: basicConfig does nothing if the root logger already has handlers;
        # in that case no file is created at full_name.
        logging.basicConfig(filename=full_name, filemode='a', datefmt='%H:%M:%S', level=logging.DEBUG,
                            format='%(asctime)s %(name)s - %(levelname)s:   %(message)s')
        self._debug = debug
        self.log_path = full_name  # path requested for the log file
        # Push any pending console output out before logging starts.
        getattr(sys.stdout, 'flush', lambda: None)()
        getattr(sys.stderr, 'flush', lambda: None)()
        LogWrapper._opened = True

    @staticmethod
    def _build_log_path(log_dir: str, calling_file: Optional[str], stamp: str,
                        rand_id: int, pycharm: bool) -> str:
        """Return the log file path built from its name components."""
        if calling_file is None:
            stem = f'{stamp}_{rand_id}'
        else:
            stem = f'{calling_file}_{stamp}_{rand_id}'
        # Only the extension is decorated, so '.txt' elsewhere in the path is untouched.
        suffix = '_PYCHARM_call.txt' if pycharm else '.txt'
        return os.path.join(log_dir, stem + suffix)

    @staticmethod
    def _join_args(args) -> str:
        """Join the string forms of ``args`` with single spaces."""
        return ' '.join(str(x) for x in args)

    @staticmethod
    def _flush_handlers() -> None:
        """Flush every root-logger handler; safe when there are none."""
        for handler in logging.getLogger().handlers:
            handler.flush()

    @staticmethod
    def my_print(caller, txt, end='\n'):
        """Print ``txt`` to stdout, prefixed with ``caller``."""
        print(caller + ' - ' + txt, end=end)

    def log(self, caller, *args):
        """Print ``args`` and log them at DEBUG level, only when debug mode is on."""
        if self._debug:
            txt = LogWrapper._join_args(args)
            LogWrapper.my_print(caller, txt)
            logging.getLogger(caller).debug(txt)
            LogWrapper._flush_handlers()

    @staticmethod
    def force_log_and_print(caller, *args, end='\n'):
        """Print ``args`` and log them at DEBUG level regardless of debug mode."""
        txt = LogWrapper._join_args(args)
        LogWrapper.my_print(caller, txt, end=end)
        logging.getLogger(caller).debug(txt)
        LogWrapper._flush_handlers()

    def print_and_log(self, caller, *args):
        """Always print ``args``; also log them at DEBUG level when debug mode is on."""
        txt = LogWrapper._join_args(args)
        if self._debug:
            logging.getLogger(caller).debug(txt)
            LogWrapper._flush_handlers()
        LogWrapper.my_print(caller, txt)

    @staticmethod
    def warning(caller, *args):
        """Emit a Python warning and log ``args`` at WARNING level."""
        txt = LogWrapper._join_args(args)
        # stacklevel=2 attributes the warning to the code that called LogWrapper.warning.
        warnings.warn(f'{caller} - {txt}', stacklevel=2)
        logging.getLogger(caller).warning(txt)
        LogWrapper._flush_handlers()

    @staticmethod
    def error(caller, err, *args):
        """Print an error banner and message, log it at ERROR level, optionally raise.

        :param err: exception class to raise with the joined message, or None.
        :raises err: ``err(txt)`` when ``err`` is not None.
        """
        txt = LogWrapper._join_args(args)
        print('!'*10, '\tERROR\t', '!'*10)
        LogWrapper.my_print(caller, txt)
        logging.getLogger(caller).error(txt)
        LogWrapper._flush_handlers()
        if err is not None:
            raise err(txt)
